# Load R packages  ----
# Attach the project's packages and resolve project file paths.
#
# Arguments:
#   tidylog  if TRUE (default), also attach tidylog so dplyr/tidyr verbs
#            print a summary of what they changed.
#   envir    optional environment (e.g. `globalenv()`) into which the path
#            variables `rootFolder`, `clean`, `temp` and `scripts` are
#            assigned. With the default NULL nothing is assigned, because
#            locals created inside this function vanish when it returns.
# Returns (invisibly) the `scripts` path, as computed by here("scripts").
setup_plgha <- function(tidylog = TRUE, envir = NULL){
  # Attach order matters for masking (e.g. tidyverse after data.table,
  # tidylog last so its verbs win), so keep this sequence as is.
  suppressPackageStartupMessages({
    library(here)
    library(tabulator)
    library(janitor)
    library(gtsummary)
    library(ipumsr)
    library(fixest)
    library(broom)
    library(labelled)
    library(dotwhisker)
    library(hrbrthemes)
    library(survey)
    library(corrplot)
    library(PerformanceAnalytics)
    library(ggsci)
    library(scales)
    library(xtable)
    library(modelsummary)
    library(data.table)
    library(patchwork)
    library(ggrepel)
    library(ggtext)
    library(marginaleffects)
    library(sf)
    library(haven)
    library(tidyverse)
    library(progress)
    if(tidylog){library(tidylog)}
  })

  # Project file paths, resolved relative to the here() project root
  paths <- list(
    rootFolder = here(),
    clean = here("clean"),
    temp = here("temp"),
    scripts = here("scripts")
  )

  # Only publish the paths when the caller explicitly asks for it, so the
  # default call never overwrites variables in the caller's workspace
  if (!is.null(envir)) {
    list2env(paths, envir = envir)
  }

  # Same invisible value the function has always returned
  invisible(paths[["scripts"]])
}

# Operators ----

# `%notin%`: the negation of `%in%`. For each element of `x`, TRUE when it
# does not occur in `table` (NA counts as matching NA, exactly as in `%in%`).
`%notin%` <- function(x, table) !(x %in% table)

# Select the 8 core methods present on all SDP surveys.
#
# Keeps the first-matching columns for each method prefix (renamed to the
# short method code) plus `INJANY`, then appends a readable method name to
# each column's variable label via varlbl_suffix().
# NOTE(review): if a prefix matches several columns, dplyr::select() numbers
# the renamed outputs (e.g. CON1, CON2) and those get no label suffix.
core8 <- function(dat){
  method_cols <- select(
    dat,
    CON = starts_with("CON"),
    CYCB = starts_with("CYCB"),
    EMRG = starts_with("EMRG"),
    FC = starts_with("FC"),
    IMP = starts_with("IMP"),
    IUD = starts_with("IUD"),
    PILL = starts_with("PILL"),
    INJANY
  )

  varlbl_suffix(
    method_cols,
    CON = "Male condoms",
    CYCB = "Cycle beads",
    EMRG = "Emergency contraception",
    FC = "Female condoms",
    IMP = "Implants",
    IUD = "IUDs",
    PILL = "Pill",
    INJANY = "Injectables"
  )
}

# Selection helper for the short method columns; call it inside a
# tidyselect context such as select() or across().
shorts <- function(){
  short_methods <- c("CON", "CYCB", "FC", "EMRG", "PILL")
  all_of(short_methods)
}
# Selection helper for the long method columns (implants, IUDs, any
# injectable); call it inside a tidyselect context.
longs <- function(){
  long_methods <- c("IMP", "IUD", "INJANY")
  all_of(long_methods)
}
# Selection helper for all injectable-related columns: positions of columns
# starting with "INJ", then "DEPO", then "SAY" (tidyselect's starts_with()
# ignores case by default). Call it inside a tidyselect context.
injs <- function(){
  inj_cols <- starts_with("INJ")
  depo_cols <- starts_with("DEPO")
  say_cols <- starts_with("SAY")
  c(inj_cols, depo_cols, say_cols)
}

# Variable Labeling Functions ---- 

# Paste a string to the end of an existing variable label (sep = ": ").
#
# Arguments:
#   .data    a data frame.
#   .strict  if TRUE, stop when any name given in `...` is not a column of
#            `.data`; otherwise such names are silently skipped.
#   ...      name = value pairs; the label of column `name` becomes
#            "<old label>: <value>". A column that has no label yet gets
#            `value` on its own (rather than ": <value>").
# Returns `.data` with the updated labels.
varlbl_suffix <- function(.data, .strict = FALSE, ...){
  values <- rlang::dots_list(...)
  if (length(values) > 0) {
    # `&&`: both operands are scalar and the check short-circuits
    if (.strict && !all(names(values) %in% names(.data))) {
      stop("some variables not found in .data")
    }
    for (v in intersect(names(values), names(.data))) {
      old_label <- var_label(.data[[v]])
      var_label(.data[[v]]) <- if (is.null(old_label)) {
        values[[v]]
      } else {
        paste0(old_label, ": ", values[[v]])
      }
    }
  }
  .data
}

# Paste a string to the beginning of an existing variable label (sep = ": ").
#
# Arguments:
#   .data    a data frame.
#   .strict  if TRUE, stop when any name given in `...` is not a column of
#            `.data`; otherwise such names are silently skipped.
#   ...      name = value pairs; the label of column `name` becomes
#            "<value>: <old label>". A column that has no label yet gets
#            `value` on its own.
# Returns `.data` with the updated labels.
varlbl_prefix <- function(.data, .strict = FALSE, ...){
  values <- rlang::dots_list(...)
  if (length(values) > 0) {
    # `&&`: both operands are scalar and the check short-circuits
    if (.strict && !all(names(values) %in% names(.data))) {
      stop("some variables not found in .data")
    }
    for (v in intersect(names(values), names(.data))) {
      old_label <- var_label(.data[[v]])
      # Use the ": " separator, matching varlbl_suffix()
      var_label(.data[[v]]) <- if (is.null(old_label)) {
        values[[v]]
      } else {
        paste0(values[[v]], ": ", old_label)
      }
    }
  }
  .data
}

# Label regularly-used SDP summary variables.
#
# Appends a readable description to the label of each listed summary column
# that exists in `dat`; absent columns are skipped (`.strict = FALSE`).
# Arguments:
#   dat   a data frame.
#   note  optional text; when given, every built-in description is suffixed
#         with " (<note>)".
#   ...   extra name = description pairs forwarded to varlbl_suffix()
#         unchanged (the note is not added to these).
core_lbl <- function(dat, note = NULL, ...){
  note_txt <- if (is.null(note)) NULL else paste0(" (", note, ")")
  with_note <- function(desc) paste0(desc, note_txt)

  varlbl_suffix(
    dat,
    .strict = FALSE,
    CON = with_note("Male Condoms"),
    CYCB = with_note("Cycle beads"),
    EMRG = with_note("Emergency methods"),
    FC = with_note("Female Condoms"),
    IMP = with_note("Implants"),
    IUD = with_note("IUDs"),
    PILL = with_note("Pills"),
    INJANY = with_note("Injectables"),
    RHY = with_note("Rhythm"),
    WD = with_note("Withdrawal"),
    LAM = with_note("LAM"),
    TRADOTH = with_note("Other traditional methods"),
    SHORT = with_note("Any short-acting methods"),
    SHORTNUM = with_note("Number of short-acting methods"),
    LONG = with_note("Any LARCs"),
    LONGNUM = with_note("Number of LARCs"),
    CORE8 = with_note("Any core methods"),
    CORE8NUM = with_note("Number of core methods"),
    ...
  )
}

# Modeling Functions ---- 

# All DiD models for woman-level data.
#
# For each outcome in `...` (separately within each level of `groupvar` when
# given), fits a set of feols() models with country + year fixed effects,
# clusters on `clustervar`, keeps only the interaction terms of the tidied
# results, and parses them with did_format().
#
# Arguments:
#   data        woman-level data with the outcomes, plgha_on, exposure,
#               ln_aidpc, exposure_ghes, exposure_fprh, the controls (age,
#               wealtht, educattgen, married, urban), country and year.
#               Each outcome must carry a variable label (map_chr() below
#               errors on a missing one).
#   clustervar  unquoted clustering column.
#   ...         unquoted outcome columns.
#   groupvar    optional unquoted column; models are fit per group.
#   basic       if TRUE, fit only the two PLGHA models (controlled and
#               uncontrolled); otherwise also the GDiD, log-aid, GHES and
#               FPRH models.
# Returns one row per outcome x model x interaction term (x group) with
# `outcome`, `model`, `periodicity`, `label`, the tidy() statistics, and
# `period`/`exposure` from did_format().
wmn_did <- function(data, clustervar, ..., groupvar = NULL, basic = FALSE){
  yvars <- rlang::enquos(...)
  clustervar <- rlang::enquo(clustervar)

  # Prepare progress bar. The model lambda below runs (and ticks) once per
  # outcome per group, so size the bar from the real number of groups
  # (a NULL `groupvar` gives one group) instead of assuming exactly two.
  n_grp <- n_groups(group_by(data, {{groupvar}}))
  pb <- progress_bar$new(
    "Generating models: [:bar] :percent eta: :eta",
    total = length(yvars) * n_grp,
    clear = FALSE,
    width = 100
  )

  data %>%
    # group output by `groupvar` if provided
    group_by({{groupvar}}) %>%

    # All `feols` models go here (`.x` is the outcome being modeled).
    # NOTE: `!!clustervar` relies on tidy-eval injection by this summarise()
    # call, so the feols() calls must stay inside it.
    summarise(
      .groups = "keep",
      across(
        c(!!!yvars),
        ~{pb$tick()
          suppressMessages(
            list(
              "PLGHA - Controlled" = feols(
                .x ~ plgha_on*exposure +
                  age + wealtht + educattgen + married + urban |
                  country + year,
                ~!!clustervar,
                data = cur_data()
              ),
              "PLGHA - Uncontrolled" = feols(
                .x ~  plgha_on*exposure | country + year,
                ~!!clustervar,
                data = cur_data()
              )
            ) %>%
              append(if(!basic){
                list(
                  "GDiD - Controlled" = feols(
                    .x ~ i(year, exposure, ref = 2016, ref2 = "Low") +
                      age + wealtht + educattgen + married + urban |
                      country + year,
                    ~!!clustervar,
                    data = cur_data()
                  ),
                  "PLGHA - Log Aid" = feols(
                    .x ~ plgha_on*ln_aidpc +
                      age + wealtht + educattgen + married + urban |
                      country + year,
                    ~!!clustervar,
                    data = cur_data()
                  ),
                  "PLGHA - US Aid % of Domestic Health Spending" = feols(
                    .x ~ plgha_on*exposure_ghes +
                      age + wealtht + educattgen + married + urban |
                      country + year,
                    ~!!clustervar,
                    data = cur_data()
                  ),
                  "PLGHA - FPRH Aid" = feols(
                    .x ~ plgha_on*exposure_fprh +
                      age + wealtht + educattgen + married + urban |
                      country + year,
                    ~!!clustervar,
                    data = cur_data()
                  )
                )
              })
          )
        }
      )
    ) %>%
    ungroup() %>%
    # Tidy the model results, but keep only the interaction terms.
    # Each model name "<periodicity> - <model>" is split into two columns.
    pivot_longer(
      !{{groupvar}}, # ignore groupvar if provided
      names_to = "outcome",
      values_to = "results"
    ) %>%
    mutate(
      model = names(results),
      label = map_chr(outcome, ~labelled::var_label(data[[.x]])),
      results = map(
        results, ~.x %>%
          tidy(conf.level = 0.95, conf.int = TRUE) %>%
          filter(str_detect(term, ":"))
      )
    ) %>%
    separate(model, sep = " - ", c("periodicity", "model")) %>%
    relocate(outcome, model, periodicity, results, label) %>%
    unnest(results) %>%
    did_format()
}


# All DiD models for SDP-level data.
#
# Like wmn_did(), but every outcome in `...` is a packed data-frame column
# (presumably built with tidyr::pack()). Each of its sub-columns is modeled
# separately and reported as outcome "<pack>_<sub-column>", with the
# sub-column's variable label.
#
# Arguments:
#   data        SDP-level data with the packed outcomes, plgha_on, exposure,
#               ln_aidpc, exposure_ghes, exposure_fprh, the controls (public,
#               facilitytype_3, has_electrc, urban, has_water), country and
#               year.
#   clustervar  unquoted clustering column.
#   ...         unquoted packed outcome columns.
#   groupvar    optional unquoted column; models are fit per group.
#   basic       if TRUE, fit only the two PLGHA models; otherwise also the
#               GDiD, log-aid, GHES and FPRH models.
# Returns one row per outcome x model x interaction term (x group), parsed
# by did_format().
sdp_did <- function(data, clustervar, ..., groupvar = NULL, basic = FALSE){
  yvars <- rlang::enquos(...)
  clustervar <- rlang::enquo(clustervar)

  # Prepare progress bar. The pack function below ticks once per packed
  # outcome per group, so size the bar from the real number of groups
  # (a NULL `groupvar` gives one group) instead of assuming exactly two.
  n_grp <- n_groups(group_by(data, {{groupvar}}))
  pb <- progress_bar$new(
    "Generating models: [:bar] :percent eta: :eta",
    total = length(yvars) * n_grp,
    clear = FALSE,
    width = 100
  )

  data %>%
    # group output by `groupvar` if provided
    group_by({{groupvar}}) %>%

    # All `feols` models go here (`.x` is the sub-column being modeled).
    # NOTE: `!!clustervar` relies on tidy-eval injection by this summarise()
    # call, so the feols() calls must stay inside it. Every PLGHA formula
    # lists plgha_on first so its interaction term reads
    # "plgha_onTRUE:<exposure>", the order did_format() parses.
    summarise(
      .groups = "keep",
      across(
        c(!!!yvars),
        function(pack){
          pb$tick()
          pack <- pack %>%
            imap_dfr(function(y, varname){
              suppressMessages({
                cur_data() %>%
                  mutate(.x = y) %>%
                  summarise(
                    models = list(
                      "PLGHA - Controlled" = feols(
                        .x ~ plgha_on*exposure +
                          public + facilitytype_3 + has_electrc +
                          urban + has_water |
                          country + year,
                        ~!!clustervar,
                        data = cur_data()
                      ),
                      "PLGHA - Uncontrolled" = feols(
                        .x ~ plgha_on*exposure | country + year,
                        ~!!clustervar,
                        data = cur_data()
                      )
                    ) %>%
                      append(if(!basic){
                        list(
                          "GDiD - Controlled" = feols(
                            .x ~ i(year, exposure, ref = 2016, ref2 = "Low") +
                              public + facilitytype_3 + has_electrc +
                              urban + has_water |
                              country + year,
                            ~!!clustervar,
                            data = cur_data()
                          ),
                          "PLGHA - Log Aid" = feols(
                            .x ~ plgha_on*ln_aidpc +
                              public + facilitytype_3 + has_electrc +
                              urban + has_water |
                              country + year,
                            ~!!clustervar,
                            data = cur_data()
                          ),
                          "PLGHA - US Aid % of Domestic Health Spending" = feols(
                            .x ~ plgha_on*exposure_ghes +
                              public + facilitytype_3 + has_electrc +
                              urban + has_water |
                              country + year,
                            ~!!clustervar,
                            data = cur_data()
                          ),
                          "PLGHA - FPRH Aid" = feols(
                            .x ~ plgha_on*exposure_fprh +
                              public + facilitytype_3 + has_electrc +
                              urban + has_water |
                              country + year,
                            ~!!clustervar,
                            data = cur_data()
                          )
                        )
                      }),
                    varname = varname,
                    label = labelled::var_label(y)
                  )
              })
            })
          list(pack)
        }
      )
    ) %>%
    ungroup() %>%
    # Tidy the model results, but keep only the interaction terms.
    # This looks a bit different from `wmn_did` because results are packed.
    pivot_longer(
      !{{groupvar}}, # ignore groupvar if provided
      names_to = "outcome",
      values_to = "results"
    ) %>%
    unnest(results) %>%
    mutate(
      model = names(models),
      outcome = paste0(outcome, "_", varname)
    ) %>%
    select(-c(varname)) %>%
    mutate(
      models = map(
        models, ~.x %>%
          tidy(conf.level = 0.95, conf.int = TRUE) %>%
          filter(str_detect(term, ":"))
      )
    ) %>%
    separate(model, sep = " - ", c("periodicity", "model")) %>%
    relocate(outcome, model, periodicity, models, label) %>%
    unnest(models) %>%
    did_format()
}

# Parse tidied DiD interaction terms into readable `period` and `exposure`
# columns.
#
# `did_results` needs a `term` column of interaction terms "<a>:<b>". The
# split is on a single ":" between alphanumeric characters, so fixest's "::"
# factor separators (as in "year::2017:exposure::High") are left intact.
# `term` is replaced by:
#   period    "PLGHA On" for the plgha_onTRUE indicator, the year (as
#             character) for year terms, NA otherwise.
#   exposure  "High exposure" for the high-exposure factor level, "Per capita
#             aid" for the log-aid term; any other exposure measure (e.g.
#             exposure_ghes, exposure_fprh) keeps its raw term name.
did_format <- function(did_results){
  did_results %>%
    separate(
      term,
      sep = "(?<=[[:alnum:]]):(?=[[:alnum:]])",
      c("period", "exposure")
    ) %>%
    # R names interactions in formula order, so `x*plgha_on` yields
    # "x:plgha_onTRUE". Swap such terms so plgha_on always lands in `period`.
    mutate(
      .swap = coalesce(exposure == "plgha_onTRUE", FALSE),
      .period = if_else(.swap, exposure, period),
      exposure = if_else(.swap, period, exposure),
      period = .period,
      .period = NULL,
      .swap = NULL
    ) %>%
    mutate(
      period = case_when(
        period == "plgha_onTRUE" ~ "PLGHA On",
        # case_when() evaluates every RHS on every row, so parse_number()
        # sees non-year strings too; silence its parsing warnings
        str_detect(period, "year") ~ suppressWarnings(
          parse_number(period) %>% as.character
        )
      ),
      exposure = case_when(
        str_detect(exposure, "exposure") &
        str_detect(exposure, "High") ~ "High exposure",
        str_detect(exposure, "aidpc$") ~ "Per capita aid",
        # keep unmapped exposure measures identifiable instead of NA
        TRUE ~ exposure
      )
    )
}

# Leave-one-country-out check of the controlled woman-level DiD model.
#
# Builds country-grouped folds with rsample::group_vfold_cv() (presumably
# one held-out country per fold with the default `v`). It then refits
# `outcome ~ exposure*plgha_on + age + wealtht + educattgen + married +
# urban | country + year` on each fold's analysis data, with `~eaid` passed
# positionally to feols().
# Note: set.seed(123) fixes fold construction but also resets the caller's
# global RNG state.
# Returns the tidied "exposureHigh:" interaction rows of every fold, stacked,
# with a `left_out` column naming the country absent from that fold.
wmn_loo <- function(data, outcome){
  set.seed(123)
  outcome <- rlang::enquo(outcome)
  all_countries <- unique(data$country)

  # Step 1 - country-grouped folds
  folds <- rsample::group_vfold_cv(data, group = "country")

  # Step 2 - fit the base model on one fold and tidy the interaction
  fit_fold <- function(split){
    fold_df <- mutate(as_tibble(split), y = !!outcome)

    # The country missing from this fold's data is the one left out
    kept <- unique(fold_df[["country"]])
    dropped <- all_countries[all_countries %notin% kept]

    fit <- feols(
      y ~
        exposure*plgha_on + age + wealtht + educattgen + married + urban |
        country + year, ~eaid,
      data = fold_df
    )

    tidied <- tidy(fit, conf.level = 0.95, conf.int = TRUE)
    tidied <- filter(tidied, str_detect(term, "exposureHigh:"))
    mutate(tidied, left_out = dropped)
  }

  map_df(folds$splits, fit_fold)
}


# Leave-one-country-out check of the controlled SDP-level DiD model.
#
# Same procedure as for woman-level data, with the SDP controls: refits
# `outcome ~ exposure*plgha_on + public + facilitytype_3 + has_electrc +
# urban + has_water | country + year` on every country-grouped fold of
# rsample::group_vfold_cv(), passing `~eaid` positionally to feols().
# Note: set.seed(123) also resets the caller's global RNG state.
# Returns the tidied "exposureHigh:" interaction rows of every fold, stacked,
# with a `left_out` column naming the country absent from that fold.
sdp_loo <- function(data, outcome){
  set.seed(123)
  outcome <- rlang::enquo(outcome)
  all_countries <- unique(data$country)

  # Step 1 - country-grouped folds
  folds <- rsample::group_vfold_cv(data, group = "country")

  # Step 2 - fit the base model on one fold and tidy the interaction
  fit_fold <- function(split){
    fold_df <- mutate(as_tibble(split), y = !!outcome)

    # The country missing from this fold's data is the one left out
    kept <- unique(fold_df[["country"]])
    dropped <- all_countries[all_countries %notin% kept]

    fit <- feols(
      y ~
        exposure*plgha_on + public + facilitytype_3 + has_electrc +
        urban + has_water |
        country + year, ~eaid,
      data = fold_df
    )

    tidied <- tidy(fit, conf.level = 0.95, conf.int = TRUE)
    tidied <- filter(tidied, str_detect(term, "exposureHigh:"))
    mutate(tidied, left_out = dropped)
  }

  map_df(folds$splits, fit_fold)
}

# Plotting Functions ---- 

# Plot standard DiD results (the exposure x PLGHA-on interaction under the
# different control sets). One estimate per outcome `label`, coloured by
# `model`, drawn horizontally via coord_flip(). Points, error bars, scales
# and theme come from plgha_theme().
did_plot <- function(models_df, title = NULL, subtitle = NULL){
  base_plot <- ggplot(
    models_df,
    aes(x = label, y = estimate, color = model, group = model)
  )
  base_plot + coord_flip() + plgha_theme(title, subtitle)
}

# Plot generalized DiD (GDiD) results: the estimate for each `period`,
# coloured and joined by `model`, with one facet per outcome `label`
# (free scales, facet titles wrapped at 20 characters). Points, error bars,
# scales and theme come from plgha_theme().
gdid_plot <- function(models_df, title = NULL, subtitle = NULL){
  models_df %>%
    ggplot(aes(x = period, y = estimate, color = model, group = model)) +
    # Same dodge width as the points/error bars in plgha_theme(), so the
    # lines pass through the dodged estimates
    geom_line(position = position_dodge(width = 0.65)) +
    # Constant intercept as a parameter (not an aesthetic): one line per
    # panel instead of one per data row. NOTE(review): 3.5 sits between the
    # 3rd and 4th discrete x positions - presumably the policy cut-off.
    geom_vline(xintercept = 3.5, colour = "black") +
    facet_wrap(
      ~label,
      scales = "free",
      labeller = labeller(label = label_wrap_gen(20))
    ) +
    plgha_theme(title, subtitle)
}


# Shared styling and layers for the DiD plots, added to a ggplot with `+`.
#
# Returns a list containing:
#   - theme_minimal() with the listed elements replaced (`%+replace%`
#     overwrites each element entirely rather than merging), legend at the
#     bottom;
#   - labels (given title/subtitle, y "Change in Outcome (p.p)");
#   - y breaks, a wrapped and reversed discrete x scale, and the npg colour
#     palette;
#   - a dashed zero reference line;
#   - dodged points and error bars, which need `conf.low`/`conf.high`.
plgha_theme <- function(title = NULL, subtitle = NULL){
  list(
    theme_minimal() %+replace%
      theme(
        axis.text.y = element_text(size = 10),
        axis.text.x = element_text(size = 10),
        legend.text = element_text(size = 12),
        # axis.title.y = element_blank(),
        axis.ticks.y = element_blank(),
        # panel.grid.major = element_blank(),
        # panel.grid.minor = element_blank(),
        strip.text = element_text(size = 13),
        axis.title.x = element_text(size = 13),
        legend.position = "bottom"
      ),
    labs(
      title = title,
      subtitle = subtitle,
      x = NULL,
      y = "Change in Outcome (p.p)",
      color = NULL
    ),
    scale_y_continuous(breaks = pretty_breaks(6)),
    scale_x_discrete(
      labels = function(x) str_wrap(x, width = 25),
      limits = rev
    ),
    scale_color_npg(labels = function(x) str_wrap(x, width = 15)),
    # Constant intercept as a parameter (not an aesthetic): one line per
    # panel instead of one per data row
    geom_hline(
      yintercept = 0,
      colour = "grey60",
      linetype = 2
    ),
    geom_point(
      size = 1,
      position = position_dodge(width = 0.65)
    ),
    geom_errorbar(
      aes(ymin = `conf.low`, ymax = `conf.high`),
      width = 0.15,
      size = 0.35,
      position = position_dodge(width = 0.65)
    )
  )
}



